{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入所需包\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import cross_val_score\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>total_magnetization_per_atom</th>\n",
       "      <th>MagpieData mode Number</th>\n",
       "      <th>MagpieData minimum MendeleevNumber</th>\n",
       "      <th>MagpieData avg_dev MendeleevNumber</th>\n",
       "      <th>MagpieData mode MendeleevNumber</th>\n",
       "      <th>MagpieData mode AtomicWeight</th>\n",
       "      <th>MagpieData minimum Electronegativity</th>\n",
       "      <th>MagpieData mode Electronegativity</th>\n",
       "      <th>MagpieData avg_dev NdValence</th>\n",
       "      <th>MagpieData range NfValence</th>\n",
       "      <th>...</th>\n",
       "      <th>MagpieData minimum NValence</th>\n",
       "      <th>MagpieData mean NdUnfilled</th>\n",
       "      <th>MagpieData avg_dev NdUnfilled</th>\n",
       "      <th>MagpieData maximum NfUnfilled</th>\n",
       "      <th>MagpieData range NfUnfilled</th>\n",
       "      <th>MagpieData mean NfUnfilled</th>\n",
       "      <th>MagpieData avg_dev NfUnfilled</th>\n",
       "      <th>MagpieData mode NfUnfilled</th>\n",
       "      <th>MagpieData mean GSmagmom</th>\n",
       "      <th>MagpieData avg_dev GSmagmom</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.225675</td>\n",
       "      <td>16.0</td>\n",
       "      <td>43.0</td>\n",
       "      <td>21.551020</td>\n",
       "      <td>88.0</td>\n",
       "      <td>32.065000</td>\n",
       "      <td>1.540</td>\n",
       "      <td>2.580</td>\n",
       "      <td>1.142857</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.285714</td>\n",
       "      <td>3.755102</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000006</td>\n",
       "      <td>0.000009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.100454</td>\n",
       "      <td>9.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>38.520000</td>\n",
       "      <td>93.0</td>\n",
       "      <td>18.998403</td>\n",
       "      <td>0.820</td>\n",
       "      <td>3.980</td>\n",
       "      <td>0.360000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>1.440000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000002</td>\n",
       "      <td>0.000004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.714286</td>\n",
       "      <td>8.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>30.367347</td>\n",
       "      <td>87.0</td>\n",
       "      <td>15.999400</td>\n",
       "      <td>1.185</td>\n",
       "      <td>3.440</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>...</td>\n",
       "      <td>6.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.428571</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.894973</td>\n",
       "      <td>16.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>33.506925</td>\n",
       "      <td>88.0</td>\n",
       "      <td>32.065000</td>\n",
       "      <td>0.790</td>\n",
       "      <td>2.580</td>\n",
       "      <td>1.994460</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.842105</td>\n",
       "      <td>1.329640</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.444350</td>\n",
       "      <td>0.701605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.899876</td>\n",
       "      <td>63.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>18.765432</td>\n",
       "      <td>25.0</td>\n",
       "      <td>151.964000</td>\n",
       "      <td>1.185</td>\n",
       "      <td>1.185</td>\n",
       "      <td>4.444444</td>\n",
       "      <td>7.0</td>\n",
       "      <td>...</td>\n",
       "      <td>9.0</td>\n",
       "      <td>0.444444</td>\n",
       "      <td>0.493827</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.888889</td>\n",
       "      <td>3.456790</td>\n",
       "      <td>7.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   total_magnetization_per_atom  MagpieData mode Number  \\\n",
       "0                      0.225675                    16.0   \n",
       "1                      0.100454                     9.0   \n",
       "2                      2.714286                     8.0   \n",
       "3                      0.894973                    16.0   \n",
       "4                      3.899876                    63.0   \n",
       "\n",
       "   MagpieData minimum MendeleevNumber  MagpieData avg_dev MendeleevNumber  \\\n",
       "0                                43.0                           21.551020   \n",
       "1                                 2.0                           38.520000   \n",
       "2                                25.0                           30.367347   \n",
       "3                                 5.0                           33.506925   \n",
       "4                                25.0                           18.765432   \n",
       "\n",
       "   MagpieData mode MendeleevNumber  MagpieData mode AtomicWeight  \\\n",
       "0                             88.0                     32.065000   \n",
       "1                             93.0                     18.998403   \n",
       "2                             87.0                     15.999400   \n",
       "3                             88.0                     32.065000   \n",
       "4                             25.0                    151.964000   \n",
       "\n",
       "   MagpieData minimum Electronegativity  MagpieData mode Electronegativity  \\\n",
       "0                                 1.540                              2.580   \n",
       "1                                 0.820                              3.980   \n",
       "2                                 1.185                              3.440   \n",
       "3                                 0.790                              2.580   \n",
       "4                                 1.185                              1.185   \n",
       "\n",
       "   MagpieData avg_dev NdValence  MagpieData range NfValence  ...  \\\n",
       "0                      1.142857                         0.0  ...   \n",
       "1                      0.360000                         0.0  ...   \n",
       "2                      0.000000                         7.0  ...   \n",
       "3                      1.994460                         0.0  ...   \n",
       "4                      4.444444                         7.0  ...   \n",
       "\n",
       "   MagpieData minimum NValence  MagpieData mean NdUnfilled  \\\n",
       "0                          4.0                    3.285714   \n",
       "1                          1.0                    0.800000   \n",
       "2                          6.0                    0.000000   \n",
       "3                          1.0                    0.842105   \n",
       "4                          9.0                    0.444444   \n",
       "\n",
       "   MagpieData avg_dev NdUnfilled  MagpieData maximum NfUnfilled  \\\n",
       "0                       3.755102                            0.0   \n",
       "1                       1.440000                            0.0   \n",
       "2                       0.000000                            7.0   \n",
       "3                       1.329640                            0.0   \n",
       "4                       0.493827                            7.0   \n",
       "\n",
       "   MagpieData range NfUnfilled  MagpieData mean NfUnfilled  \\\n",
       "0                          0.0                    0.000000   \n",
       "1                          0.0                    0.000000   \n",
       "2                          7.0                    3.000000   \n",
       "3                          0.0                    0.000000   \n",
       "4                          7.0                    3.888889   \n",
       "\n",
       "   MagpieData avg_dev NfUnfilled  MagpieData mode NfUnfilled  \\\n",
       "0                       0.000000                         0.0   \n",
       "1                       0.000000                         0.0   \n",
       "2                       3.428571                         0.0   \n",
       "3                       0.000000                         0.0   \n",
       "4                       3.456790                         7.0   \n",
       "\n",
       "   MagpieData mean GSmagmom  MagpieData avg_dev GSmagmom  \n",
       "0                  0.000006                     0.000009  \n",
       "1                  0.000002                     0.000004  \n",
       "2                  0.000000                     0.000000  \n",
       "3                  0.444350                     0.701605  \n",
       "4                  0.000000                     0.000000  \n",
       "\n",
       "[5 rows x 21 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 导入材料数据\n",
    "def load_data(csv_file): \n",
    "    return pd.read_csv(csv_file, encoding = 'utf-8')\n",
    "data_form = load_data(\"rfr_data_set_after_two_step_features_selection.csv文件所在的位置\")\n",
    "data_form.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# random_state : 机器学习结果重现设置值\n",
    "data_train, data_test = train_test_split(\n",
    "    data_form, \n",
    "    test_size = 0.2, \n",
    "    shuffle = True, \n",
    "    random_state = 20210606\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 提取出无机铁磁性材料训练数据集中的特征矩阵\n",
    "X_train = data_train.drop(['total_magnetization_per_atom'], axis = 1)\n",
    "# 提取出无机铁磁性材料训练数据集中的磁矩\n",
    "y_train = data_train['total_magnetization_per_atom']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MagpieData mode Number</th>\n",
       "      <th>MagpieData minimum MendeleevNumber</th>\n",
       "      <th>MagpieData avg_dev MendeleevNumber</th>\n",
       "      <th>MagpieData mode MendeleevNumber</th>\n",
       "      <th>MagpieData mode AtomicWeight</th>\n",
       "      <th>MagpieData minimum Electronegativity</th>\n",
       "      <th>MagpieData mode Electronegativity</th>\n",
       "      <th>MagpieData avg_dev NdValence</th>\n",
       "      <th>MagpieData range NfValence</th>\n",
       "      <th>MagpieData avg_dev NfValence</th>\n",
       "      <th>MagpieData minimum NValence</th>\n",
       "      <th>MagpieData mean NdUnfilled</th>\n",
       "      <th>MagpieData avg_dev NdUnfilled</th>\n",
       "      <th>MagpieData maximum NfUnfilled</th>\n",
       "      <th>MagpieData range NfUnfilled</th>\n",
       "      <th>MagpieData mean NfUnfilled</th>\n",
       "      <th>MagpieData avg_dev NfUnfilled</th>\n",
       "      <th>MagpieData mode NfUnfilled</th>\n",
       "      <th>MagpieData mean GSmagmom</th>\n",
       "      <th>MagpieData avg_dev GSmagmom</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>22657</th>\n",
       "      <td>8.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>25.420000</td>\n",
       "      <td>87.0</td>\n",
       "      <td>15.999400</td>\n",
       "      <td>0.820</td>\n",
       "      <td>3.440</td>\n",
       "      <td>2.240000</td>\n",
       "      <td>14.0</td>\n",
       "      <td>4.900000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.600000</td>\n",
       "      <td>1.080000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>0.700000</td>\n",
       "      <td>1.260000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27954</th>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>30.888889</td>\n",
       "      <td>93.0</td>\n",
       "      <td>18.998403</td>\n",
       "      <td>0.980</td>\n",
       "      <td>3.980</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.166667</td>\n",
       "      <td>1.944444</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14828</th>\n",
       "      <td>8.0</td>\n",
       "      <td>43.0</td>\n",
       "      <td>11.880000</td>\n",
       "      <td>87.0</td>\n",
       "      <td>15.999400</td>\n",
       "      <td>1.540</td>\n",
       "      <td>3.440</td>\n",
       "      <td>3.520000</td>\n",
       "      <td>14.0</td>\n",
       "      <td>4.480000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.200000</td>\n",
       "      <td>1.920000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.211069</td>\n",
       "      <td>0.379919</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3454</th>\n",
       "      <td>17.0</td>\n",
       "      <td>55.0</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>94.0</td>\n",
       "      <td>35.453000</td>\n",
       "      <td>1.830</td>\n",
       "      <td>3.160</td>\n",
       "      <td>1.080000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.720000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.211066</td>\n",
       "      <td>0.379919</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11436</th>\n",
       "      <td>17.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>31.920000</td>\n",
       "      <td>94.0</td>\n",
       "      <td>35.453000</td>\n",
       "      <td>0.820</td>\n",
       "      <td>3.160</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.600000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000062</td>\n",
       "      <td>0.000099</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10560</th>\n",
       "      <td>29.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>19.555556</td>\n",
       "      <td>25.0</td>\n",
       "      <td>63.546000</td>\n",
       "      <td>1.185</td>\n",
       "      <td>1.185</td>\n",
       "      <td>4.444444</td>\n",
       "      <td>7.0</td>\n",
       "      <td>3.111111</td>\n",
       "      <td>9.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>2.333333</td>\n",
       "      <td>3.111111</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13685</th>\n",
       "      <td>8.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>22.302021</td>\n",
       "      <td>87.0</td>\n",
       "      <td>15.999400</td>\n",
       "      <td>0.980</td>\n",
       "      <td>3.440</td>\n",
       "      <td>0.770511</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.275862</td>\n",
       "      <td>0.513674</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.145563</td>\n",
       "      <td>0.271048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3591</th>\n",
       "      <td>6.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>18.000000</td>\n",
       "      <td>77.0</td>\n",
       "      <td>12.010700</td>\n",
       "      <td>1.200</td>\n",
       "      <td>2.550</td>\n",
       "      <td>2.125000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>2.625000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.250000</td>\n",
       "      <td>3.250000</td>\n",
       "      <td>7.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.750000</td>\n",
       "      <td>2.625000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.527666</td>\n",
       "      <td>0.791499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12857</th>\n",
       "      <td>27.0</td>\n",
       "      <td>47.0</td>\n",
       "      <td>5.500000</td>\n",
       "      <td>47.0</td>\n",
       "      <td>58.933195</td>\n",
       "      <td>1.600</td>\n",
       "      <td>1.600</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.500000</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.774236</td>\n",
       "      <td>0.774236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18453</th>\n",
       "      <td>8.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>31.843750</td>\n",
       "      <td>87.0</td>\n",
       "      <td>15.999400</td>\n",
       "      <td>0.980</td>\n",
       "      <td>3.440</td>\n",
       "      <td>1.708984</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.093750</td>\n",
       "      <td>1.708984</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000068</td>\n",
       "      <td>0.000106</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>25798 rows × 20 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       MagpieData mode Number  MagpieData minimum MendeleevNumber  \\\n",
       "22657                     8.0                                 3.0   \n",
       "27954                     9.0                                 1.0   \n",
       "14828                     8.0                                43.0   \n",
       "3454                     17.0                                55.0   \n",
       "11436                    17.0                                 3.0   \n",
       "...                       ...                                 ...   \n",
       "10560                    29.0                                25.0   \n",
       "13685                     8.0                                 1.0   \n",
       "3591                      6.0                                27.0   \n",
       "12857                    27.0                                47.0   \n",
       "18453                     8.0                                 1.0   \n",
       "\n",
       "       MagpieData avg_dev MendeleevNumber  MagpieData mode MendeleevNumber  \\\n",
       "22657                           25.420000                             87.0   \n",
       "27954                           30.888889                             93.0   \n",
       "14828                           11.880000                             87.0   \n",
       "3454                             8.000000                             94.0   \n",
       "11436                           31.920000                             94.0   \n",
       "...                                   ...                              ...   \n",
       "10560                           19.555556                             25.0   \n",
       "13685                           22.302021                             87.0   \n",
       "3591                            18.000000                             77.0   \n",
       "12857                            5.500000                             47.0   \n",
       "18453                           31.843750                             87.0   \n",
       "\n",
       "       MagpieData mode AtomicWeight  MagpieData minimum Electronegativity  \\\n",
       "22657                     15.999400                                 0.820   \n",
       "27954                     18.998403                                 0.980   \n",
       "14828                     15.999400                                 1.540   \n",
       "3454                      35.453000                                 1.830   \n",
       "11436                     35.453000                                 0.820   \n",
       "...                             ...                                   ...   \n",
       "10560                     63.546000                                 1.185   \n",
       "13685                     15.999400                                 0.980   \n",
       "3591                      12.010700                                 1.200   \n",
       "12857                     58.933195                                 1.600   \n",
       "18453                     15.999400                                 0.980   \n",
       "\n",
       "       MagpieData mode Electronegativity  MagpieData avg_dev NdValence  \\\n",
       "22657                              3.440                      2.240000   \n",
       "27954                              3.980                      0.833333   \n",
       "14828                              3.440                      3.520000   \n",
       "3454                               3.160                      1.080000   \n",
       "11436                              3.160                      1.600000   \n",
       "...                                  ...                           ...   \n",
       "10560                              1.185                      4.444444   \n",
       "13685                              3.440                      0.770511   \n",
       "3591                               2.550                      2.125000   \n",
       "12857                              1.600                      1.500000   \n",
       "18453                              3.440                      1.708984   \n",
       "\n",
       "       MagpieData range NfValence  MagpieData avg_dev NfValence  \\\n",
       "22657                        14.0                      4.900000   \n",
       "27954                         0.0                      0.000000   \n",
       "14828                        14.0                      4.480000   \n",
       "3454                          0.0                      0.000000   \n",
       "11436                         0.0                      0.000000   \n",
       "...                           ...                           ...   \n",
       "10560                         7.0                      3.111111   \n",
       "13685                         0.0                      0.000000   \n",
       "3591                          7.0                      2.625000   \n",
       "12857                         0.0                      0.000000   \n",
       "18453                         0.0                      0.000000   \n",
       "\n",
       "       MagpieData minimum NValence  MagpieData mean NdUnfilled  \\\n",
       "22657                          1.0                    0.600000   \n",
       "27954                          1.0                    1.166667   \n",
       "14828                          4.0                    1.200000   \n",
       "3454                           5.0                    0.400000   \n",
       "11436                          1.0                    1.000000   \n",
       "...                            ...                         ...   \n",
       "10560                          9.0                    0.000000   \n",
       "13685                          1.0                    0.275862   \n",
       "3591                           4.0                    3.250000   \n",
       "12857                          5.0                    4.500000   \n",
       "18453                          1.0                    1.093750   \n",
       "\n",
       "       MagpieData avg_dev NdUnfilled  MagpieData maximum NfUnfilled  \\\n",
       "22657                       1.080000                            7.0   \n",
       "27954                       1.944444                            0.0   \n",
       "14828                       1.920000                            0.0   \n",
       "3454                        0.720000                            0.0   \n",
       "11436                       1.600000                            0.0   \n",
       "...                              ...                            ...   \n",
       "10560                       0.000000                            7.0   \n",
       "13685                       0.513674                            0.0   \n",
       "3591                        3.250000                            7.0   \n",
       "12857                       1.500000                            0.0   \n",
       "18453                       1.708984                            0.0   \n",
       "\n",
       "       MagpieData range NfUnfilled  MagpieData mean NfUnfilled  \\\n",
       "22657                          7.0                    0.700000   \n",
       "27954                          0.0                    0.000000   \n",
       "14828                          0.0                    0.000000   \n",
       "3454                           0.0                    0.000000   \n",
       "11436                          0.0                    0.000000   \n",
       "...                            ...                         ...   \n",
       "10560                          7.0                    2.333333   \n",
       "13685                          0.0                    0.000000   \n",
       "3591                           7.0                    1.750000   \n",
       "12857                          0.0                    0.000000   \n",
       "18453                          0.0                    0.000000   \n",
       "\n",
       "       MagpieData avg_dev NfUnfilled  MagpieData mode NfUnfilled  \\\n",
       "22657                       1.260000                         0.0   \n",
       "27954                       0.000000                         0.0   \n",
       "14828                       0.000000                         0.0   \n",
       "3454                        0.000000                         0.0   \n",
       "11436                       0.000000                         0.0   \n",
       "...                              ...                         ...   \n",
       "10560                       3.111111                         0.0   \n",
       "13685                       0.000000                         0.0   \n",
       "3591                        2.625000                         0.0   \n",
       "12857                       0.000000                         0.0   \n",
       "18453                       0.000000                         0.0   \n",
       "\n",
       "       MagpieData mean GSmagmom  MagpieData avg_dev GSmagmom  \n",
       "22657                  0.000000                     0.000000  \n",
       "27954                  0.000000                     0.000000  \n",
       "14828                  0.211069                     0.379919  \n",
       "3454                   0.211066                     0.379919  \n",
       "11436                  0.000062                     0.000099  \n",
       "...                         ...                          ...  \n",
       "10560                  0.000000                     0.000000  \n",
       "13685                  0.145563                     0.271048  \n",
       "3591                   0.527666                     0.791499  \n",
       "12857                  0.774236                     0.774236  \n",
       "18453                  0.000068                     0.000106  \n",
       "\n",
       "[25798 rows x 20 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看X_train\n",
    "X_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "22657    0.700000\n",
       "27954    0.333334\n",
       "14828    0.500002\n",
       "3454     0.497234\n",
       "11436    1.000090\n",
       "           ...   \n",
       "10560    2.361337\n",
       "13685    0.310349\n",
       "3591     1.981722\n",
       "12857    0.163378\n",
       "18453    0.812417\n",
       "Name: total_magnetization_per_atom, Length: 25798, dtype: float64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看y_train\n",
    "y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 提取出无机铁磁性材料测试数据集中的特征矩阵\n",
    "X_test = data_test.drop(['total_magnetization_per_atom'], axis = 1)\n",
    "# 提取出无机铁磁性材料测试数据集中的磁矩\n",
    "y_test = data_test['total_magnetization_per_atom']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建RFR无机铁磁性材料磁矩预测模型\n",
    "rfr_pipeline = Pipeline([\n",
    "    # 缺失值填充\n",
    "    ('imputer', SimpleImputer(missing_values=np.nan, strategy='mean')),\n",
    "    # RFC模型\n",
    "    ('rfc', RandomForestRegressor(n_estimators = 300, \n",
    "                                   max_features = 'auto', \n",
    "                                   min_samples_leaf = 1, \n",
    "                                   min_samples_split = 2, \n",
    "                                   n_jobs = -1)\n",
    "    )\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "         steps=[('imputer',\n",
       "                 SimpleImputer(add_indicator=False, copy=True, fill_value=None,\n",
       "                               missing_values=nan, strategy='mean',\n",
       "                               verbose=0)),\n",
       "                ('rfc',\n",
       "                 RandomForestRegressor(bootstrap=True, ccp_alpha=0.0,\n",
       "                                       criterion='mse', max_depth=None,\n",
       "                                       max_features='auto', max_leaf_nodes=None,\n",
       "                                       max_samples=None,\n",
       "                                       min_impurity_decrease=0.0,\n",
       "                                       min_impurity_split=None,\n",
       "                                       min_samples_leaf=1, min_samples_split=2,\n",
       "                                       min_weight_fraction_leaf=0.0,\n",
       "                                       n_estimators=300, n_jobs=-1,\n",
       "                                       oob_score=False, random_state=None,\n",
       "                                       verbose=0, warm_start=False))],\n",
       "         verbose=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 构建模型\n",
    "rfr_pipeline.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 磁矩预测训练结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "磁矩预测模型在训练集上的10折交叉验证结果：\n",
      "-----------------------------------------------------------------------\n",
      "拟合优度的每折结果：\n",
      "[0.9401572  0.94141813 0.93853005 0.94702985 0.9527148  0.95690168\n",
      " 0.94406655 0.95053114 0.95474861 0.94620879]\n",
      "拟合优度的平均值：\n",
      "0    0.947231\n",
      "dtype: float64\n",
      "-----------------------------------------------------------------------\n",
      "平均绝对误差的每折结果：\n",
      "[0.0684056517028755, 0.0661447795919321, 0.07017578751235244, 0.06695395458362276, 0.07005867779876748, 0.0677850109565217, 0.0688750494389076, 0.06426092715398109, 0.06770515421862372, 0.0690966342576126]\n",
      "平均绝对误差的平均值：\n",
      "0    0.067946\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# 磁矩预测模型在训练集上的10折交叉验证结果\n",
    "print(\"磁矩预测模型在训练集上的10折交叉验证结果：\")\n",
    "print(\"-----------------------------------------------------------------------\")\n",
    "# 拟合优度\n",
    "cross_val_score_r2 = cross_val_score(rfr_pipeline, X_train, y_train, scoring = 'r2', cv=10)\n",
    "# 拟合优度的每折结果\n",
    "print(\"拟合优度的每折结果：\")\n",
    "print(cross_val_score_r2)\n",
    "# 拟合优度的平均值\n",
    "print(\"拟合优度的平均值：\")\n",
    "print(pd.DataFrame(cross_val_score_r2).mean())\n",
    "print(\"-----------------------------------------------------------------------\")\n",
    "\n",
    "# 平均绝对误差\n",
    "cross_val_score_mae = list(map(abs, cross_val_score(rfr_pipeline, X_train, y_train, scoring = 'neg_mean_absolute_error', cv=10)))\n",
    "print(\"平均绝对误差的每折结果：\")\n",
    "# 平均绝对误差的每折结果\n",
    "print(cross_val_score_mae)\n",
    "print(\"平均绝对误差的平均值：\")\n",
    "# 平均绝对误差的平均值\n",
    "print(pd.DataFrame(cross_val_score_mae).mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 磁矩预测测试结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "磁矩预测模型在测试集上的10折交叉验证结果：\n",
      "-----------------------------------------------------------------------\n",
      "拟合优度的每折结果：\n",
      "[0.91425251 0.92515006 0.91279081 0.91206334 0.90919661 0.93884799\n",
      " 0.91424022 0.90852987 0.90613032 0.91883311]\n",
      "拟合优度的平均值：\n",
      "0    0.916003\n",
      "dtype: float64\n",
      "-----------------------------------------------------------------------\n",
      "平均绝对误差的每折结果：\n",
      "[0.10339843548060554, 0.09481695408594536, 0.09653563500884588, 0.09728711417245371, 0.09638881049613005, 0.098974935836454, 0.0923408202235827, 0.10280256904349672, 0.10216082189031346, 0.09110847169870732]\n",
      "平均绝对误差的平均值：\n",
      "0    0.097581\n",
      "dtype: float64\n"
     ]
    }
   ],
   "source": [
    "# 磁矩预测模型在测试集上的10折交叉验证结果\n",
    "print(\"磁矩预测模型在测试集上的10折交叉验证结果：\")\n",
    "print(\"-----------------------------------------------------------------------\")\n",
    "# 拟合优度\n",
    "cross_val_score_r2 = cross_val_score(rfr_pipeline, X_test, y_test, scoring = 'r2', cv=10)\n",
    "# 拟合优度的每折结果\n",
    "print(\"拟合优度的每折结果：\")\n",
    "print(cross_val_score_r2)\n",
    "# 拟合优度的平均值\n",
    "print(\"拟合优度的平均值：\")\n",
    "print(pd.DataFrame(cross_val_score_r2).mean())\n",
    "print(\"-----------------------------------------------------------------------\")\n",
    "\n",
    "# 平均绝对误差\n",
    "cross_val_score_mae = list(map(abs, cross_val_score(rfr_pipeline, X_test, y_test, scoring = 'neg_mean_absolute_error', cv=10)))\n",
    "print(\"平均绝对误差的每折结果：\")\n",
    "# 平均绝对误差的每折结果\n",
    "print(cross_val_score_mae)\n",
    "print(\"平均绝对误差的平均值：\")\n",
    "# 平均绝对误差的平均值\n",
    "print(pd.DataFrame(cross_val_score_mae).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
